"""
Base class for data generator.
"""
import h5py
import sys
import numpy as np
import os
import torch as th
import json
from copy import deepcopy
from omnigibson.envs import DataMimicWrapper
import random
import omnigibson.utils.transform_utils as T
class DataGenerator(object):
    """
    The main data generator object that loads a source dataset, parses it, and
    generates new trajectories by replaying a selected source demonstration in
    an environment, subtask by subtask.
    """

    # Files whose name contains any of these substrings are derived/processed
    # outputs and are never used as source demos.
    _EXCLUDED_FILE_TAGS = ("final", "filter", "replay")

    def __init__(
        self,
        task_spec,
        dataset_paths,
        error_log_path=None,
    ):
        """
        Args:
            task_spec (MG_TaskSpec instance): task specification that will be
                used to generate data. Used via len(), per-subtask indexing
                (entries provide an "object_ref" key) and its ``init_pose``.
            dataset_paths (str or list of str): directory, or directories,
                containing the source hdf5 files.
            error_log_path (str or None): optional text file to which errors
                raised while opening source files are appended. When None,
                such errors are only printed.
        """
        self.task_spec = task_spec
        self.dataset_path = dataset_paths  # historical attribute name, kept for callers
        self.files_path = []
        self.error_log_path = error_log_path

        self._load_dataset(dataset_paths=dataset_paths)

    def _load_dataset(self, dataset_paths):
        """
        Collect candidate source files from one or more dataset directories.

        Every regular file whose name contains none of "final", "filter" or
        "replay" is appended to ``self.files_path``. Entries are added in
        sorted order so random selection is reproducible under a fixed seed.

        Args:
            dataset_paths (str, os.PathLike or iterable of them): directories to scan.
        """
        if isinstance(dataset_paths, (str, os.PathLike)):
            # A bare path would otherwise be iterated character by character.
            dataset_paths = [dataset_paths]
        for dataset_path in dataset_paths:
            print("\nDataGenerator: loading dataset at path {}...".format(dataset_path))
            # Skip files that contain "final", "filter" or "replay" in their name.
            for name in sorted(os.listdir(dataset_path)):
                if any(tag in name for tag in self._EXCLUDED_FILE_TAGS):
                    continue
                full_path = os.path.join(dataset_path, name)
                if not os.path.isfile(full_path):
                    continue
                self.files_path.append(full_path)
        print(f"\nDataGenerator: done loading, found {len(self.files_path)} files after filtering\n")

    def h5py_group_to_torch(self, group):
        """
        Recursively convert an h5py group into a nested dict of float32 tensors.

        Sub-groups become nested dicts; every dataset is read fully and
        converted with ``th.tensor(..., dtype=th.float32)``.
        """
        state = {}
        for key, value in group.items():
            if isinstance(value, h5py.Group):
                state[key] = self.h5py_group_to_torch(value)
            else:
                state[key] = th.tensor(value[()], dtype=th.float32)
        return state

    def _close_current_file(self):
        """Close the currently held source file (``self.f``), if any."""
        h5file = getattr(self, "f", None)
        if h5file is not None:
            h5file.close()

    def _log_load_error(self, path, exc):
        """
        Report a source file that could not be opened.

        The message is always printed; it is also appended to
        ``self.error_log_path`` when that is set. Failures to write the log are
        reported but never propagated, so a bad log path cannot abort demo
        selection.
        """
        error_msg = f"Error loading file {path}: {exc}"
        print(error_msg)
        log_path = getattr(self, "error_log_path", None)
        if not log_path:
            return
        try:
            with open(log_path, "a") as error_log:
                error_log.write(f"{error_msg}\n")
        except OSError as log_exc:
            print(f"Could not write to error log {log_path}: {log_exc}")

    @staticmethod
    def _normalize_obj_type(obj_type):
        """Return ``obj_type["category_object"]`` for dict filters, else ``obj_type`` unchanged."""
        if isinstance(obj_type, dict) and "category_object" in obj_type:
            return obj_type["category_object"]
        return obj_type

    @staticmethod
    def _demo_matches(h5file, obj_type):
        """
        Check whether an open source file matches a string object filter.

        A file matches when its "data" group has an "info" attribute whose
        string form contains ``obj_type``. Non-string filters never match.
        """
        if not isinstance(obj_type, str):
            return False
        data_grp = h5file.get("data")
        if data_grp is None or "info" not in data_grp.attrs:
            return False
        return obj_type in str(data_grp.attrs["info"])

    def select_source_demo(
        self, obj_type=None
    ):
        """
        Pick one source demo and open it as ``self.f``.

        Args:
            obj_type (str, dict or None): when None, a file is picked uniformly
                at random. Otherwise files are tried in random order and the
                first whose "info" attribute contains ``obj_type`` is used
                (a dict is reduced to its "category_object" entry).

        Returns:
            bool: True if a demo was selected; False if no file matched, in
            which case ``self.src_demo_path`` and ``self.f`` are set to None.
        """
        self._close_current_file()
        if obj_type is None:
            self.src_demo_path = np.random.choice(self.files_path)
            self.f = h5py.File(self.src_demo_path, "r")
            print(f"\nDataGenerator: selected source demo {self.src_demo_path}\n")
            return True

        obj_type = self._normalize_obj_type(obj_type)
        # Shuffle a copy so the shared file list (possibly aliased elsewhere) keeps its order.
        candidates = list(self.files_path)
        random.shuffle(candidates)
        for path in candidates:
            try:
                h5file = h5py.File(path, "r")
            except Exception as e:  # best-effort: skip unreadable/corrupt files
                self._log_load_error(path, e)
                continue
            if self._demo_matches(h5file, obj_type):
                self.src_demo_path = path
                self.f = h5file
                print(f"\nDataGenerator: selected source demo {self.src_demo_path}\n")
                return True
            h5file.close()

        self.src_demo_path = None
        self.f = None
        print(f"\nDataGenerator: no source demo matches obj_type {obj_type!r}\n")
        return False

    def select_source_demos(self, obj_type=None):
        """
        Build ``self.src_demos``, the list of source files usable for replay.

        Args:
            obj_type (str, dict or None): when None, every loaded file is kept.
                Otherwise only files whose "info" attribute contains
                ``obj_type`` are kept (a dict is reduced to its
                "category_object" entry). Files that cannot be opened are
                reported and skipped.
        """
        if obj_type is None:
            self.src_demos = list(self.files_path)
            return

        obj_type = self._normalize_obj_type(obj_type)
        self.src_demos = []
        for path in self.files_path:
            try:
                h5file = h5py.File(path, "r")
            except Exception as e:  # best-effort: skip unreadable/corrupt files
                self._log_load_error(path, e)
                continue

            print(path)
            with h5file:  # closes the file even if reading its attributes fails
                if self._demo_matches(h5file, obj_type):
                    self.src_demos.append(path)

    def select_source_demo_in_list(self, idx):
        """
        Open ``self.src_demos[idx % len(self.src_demos)]`` as the current source demo.

        Any previously opened source file is closed first.

        Returns:
            bool: False if ``self.src_demos`` is empty, True otherwise.
        """
        if not self.src_demos:
            print("没有可用的源演示数据")
            return False
        self._close_current_file()
        self.src_demo_path = self.src_demos[idx % len(self.src_demos)]
        print(f" ===============当前回放selected source demo================= {self.src_demo_path}")
        self.f = h5py.File(self.src_demo_path, "r")
        return True

    def extract_action(self, action):
        """
        Split a flat action vector into its four parts.

        Layout: [0:6] left arm, [6:12] right arm (position + axis-angle, as
        built by ``generate_action``), [12:23] left hand, [23:34] right hand.
        The returned pieces are slices of ``action``, i.e. views sharing its
        storage.
        """
        arm_left_command = action[0:6]
        arm_right_command = action[6:12]
        hand_left_command = action[12:23]
        hand_right_command = action[23:34]
        return arm_left_command, arm_right_command, hand_left_command, hand_right_command

    def move_to_initial_position(self, env, initial_pos, step=60):
        """
        Drive the robot towards ``initial_pos`` by repeatedly stepping the
        environment with the fixed action from ``generate_action``.

        Note: the environment is stepped ``2 * step`` times. The original code
        ran two identical ``step``-long loops, and the total is kept so the
        robot has the same settling time.
        """
        for _ in range(2 * step):
            action = self.generate_action(initial_pos)
            env.env.step(action)

    def generate_action(self, initial_pos=None):
        """
        Build the fixed "initial pose" action for both arms and hands.

        Args:
            initial_pos (sequence of 3 numbers or None): left-arm target
                position; the right arm mirrors it with negated y. Defaults to
                (0.3, 0.5, 1.4).

        Returns:
            th.Tensor: 1-D action laid out as in ``extract_action``.
        """
        if initial_pos is None:
            initial_pos = th.tensor([0.3, 0.5, 1.4])
        x, y, z = (float(initial_pos[i]) for i in range(3))

        # Same orientation for both arms. The [1, 2, 3, 0] reorder before
        # quat2axisangle presumably converts a w-first quaternion to w-last
        # (TODO confirm against omnigibson.utils.transform_utils).
        arm_quat = th.tensor([-0.579228, 0.4055798, -0.579228, 0.4055798])
        arm_axisangle = T.quat2axisangle(arm_quat[[1, 2, 3, 0]])

        arm_left_action = th.cat((th.tensor([x, y, z]), arm_axisangle), 0)
        arm_right_action = th.cat((th.tensor([x, -y, z]), arm_axisangle), 0)

        # Identical joint targets for both hands.
        hand_command = th.tensor([1.57, 0.64, 0.04, 3.11, 1.57, 3.07, 1.57, 3.08, 1.57, 3.05, 1.57])
        return th.cat((arm_left_action, arm_right_action, hand_command, hand_command), 0)

    def analyze_mask(self, mask):
        """
        Find the contiguous index ranges occupied by each distinct mask value.

        Args:
            mask: data-mask tensor; indices are taken from the first dimension.

        Returns:
            dict: {value: [(start_idx, end_idx), ...]} with inclusive end
            indices, keys in ascending order of value.
        """
        mask_values = th.unique(mask)
        mask_dict = {}

        for value in mask_values:
            # All indices holding the current value.
            indices = th.where(mask == value)[0].cpu().numpy()

            if len(indices) == 0:
                continue

            # Group consecutive indices into ranges.
            ranges = []
            start_idx = indices[0]
            prev_idx = indices[0]

            for i in range(1, len(indices)):
                # A gap closes the current range.
                if indices[i] > prev_idx + 1:
                    ranges.append((int(start_idx), int(prev_idx)))
                    start_idx = indices[i]
                prev_idx = indices[i]

            # Close the last range.
            ranges.append((int(start_idx), int(prev_idx)))

            mask_dict[value.item()] = ranges

        return mask_dict

    def transform_action(self, action, cur_object_pose, src_object_pose, is_lift=True):
        """
        Shift an action's arm targets by the object's displacement.

        The xy part of (cur_object_pose - src_object_pose) is added to the
        right-arm position. The left arm either gets the same xy shift
        (``is_lift=False``) or is re-anchored to the shifted right arm while
        keeping the original left-minus-right offset (``is_lift=True``).

        Args:
            action (th.Tensor): 1-D action, or a (1, N) batch of one.
            cur_object_pose, src_object_pose: object positions (at least xy).
            is_lift (bool): see above.

        Returns:
            th.Tensor: a new 1-D action; ``action`` itself is not modified.
        """
        transformed_action = action.squeeze(0) if action.dim() == 2 else action
        # extract_action returns views, so work on a copy. Otherwise the
        # in-place slice updates below would mutate the caller's tensor, and
        # repeated calls on the same action would accumulate the offset.
        transformed_action = transformed_action.clone()

        def _as_vec(value):
            return th.as_tensor(value, dtype=transformed_action.dtype, device=transformed_action.device)

        relative_pose = _as_vec(cur_object_pose) - _as_vec(src_object_pose)

        arm_left_command, arm_right_command, hand_left_command, hand_right_command = self.extract_action(transformed_action)
        # Offset from the right arm to the left arm, captured before any shift.
        left_to_right = arm_left_command[0:3] - arm_right_command[0:3]

        arm_right_command[0:2] = arm_right_command[0:2] + relative_pose[0:2]
        if is_lift:
            arm_left_command[0:3] = arm_right_command[0:3] + left_to_right
        else:
            arm_left_command[0:2] = arm_left_command[0:2] + relative_pose[0:2]
        return th.cat((arm_left_command, arm_right_command, hand_left_command, hand_right_command), 0)

    def get_objs_state_info(self, env, info, obj_name="ball0"):
        """
        Build the object-state vector recorded with every step.

        Concatenates the object's base-aligned bbox center and quaternion, its
        linear and angular velocity, and a trailing task-stage value (always
        0.0 here).

        Args:
            env: wrapper whose ``scene.object_registry`` resolves ``obj_name``.
            info: step info (currently unused; kept for interface compatibility).
            obj_name (str): registry name of the tracked object.
        """
        obj = env.scene.object_registry("name", obj_name)
        bbox_center_in_world, bbox_quat_in_world, _, _ = obj.get_base_aligned_bbox(visual=False)
        obj_state_data = th.cat(
            [bbox_center_in_world, bbox_quat_in_world, obj.get_linear_velocity(), obj.get_angular_velocity()],
            dim=0,
        )
        task_stage_data = th.tensor([0.0])
        return th.cat([obj_state_data, task_stage_data], dim=0)

    def _step_and_record(self, env, action, current_obs):
        """
        Step the wrapped environment once and append the parsed step to the trajectory.

        Returns:
            tuple: (next_obs, terminated) from ``env.env.step``.
        """
        next_obs, reward, terminated, truncated, info = env.env.step(action=action, n_render_iterations=1)
        step_data = env._parse_step_data(
            action=action,
            current_obs=current_obs,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
            info=info,
            obj_state=self.get_objs_state_info(env, info),
        )
        env.current_traj_history.append(step_data)
        env.step_count += 1
        return next_obs, terminated

    def _replay_source_demo(self, env, hold_steps=100):
        """
        Replay demo_0 of the open source file, subtask by subtask.

        For subtask k the frames whose data_mask equals k + 1 are replayed;
        mask value 0 is skipped. If nothing has terminated after the last
        subtask, its final action is held, object-offset-adjusted, for up to
        ``hold_steps`` steps.

        Returns:
            bool: True if any step reported ``terminated``.

        Raises:
            ValueError: if a subtask's mask value does not occupy exactly one
                contiguous frame range.
        """
        data_grp = self.f["data"]
        src_scene_file = json.loads(data_grp.attrs["scene_file"])
        action = th.tensor(data_grp["demo_0/action"][()], dtype=th.float32)
        mask = th.tensor(data_grp["demo_0/data_mask"][()], dtype=th.float32)
        mask_dict = self.analyze_mask(mask)

        self.move_to_initial_position(env, initial_pos=self.task_spec.init_pose, step=45)

        flag = False
        next_obs, _ = env.env.get_obs()
        num_subtasks = len(self.task_spec)

        for subtask_ind in range(num_subtasks):
            mask_ind = subtask_ind + 1  # skip the invalid mask value 0
            env.set_data_mask(mask_ind)
            subtask_object_name = self.task_spec[subtask_ind]["object_ref"]
            subtask_object = env.scene.object_registry("name", subtask_object_name)

            cur_object_pose = subtask_object.get_position_orientation()[0]
            src_object_pose = src_scene_file["state"]["object_registry"][subtask_object_name]["root_link"]["pos"]

            ranges = mask_dict.get(mask_ind)
            if ranges is None or len(ranges) != 1:
                raise ValueError(
                    f"data_mask value {mask_ind} must cover exactly one contiguous range, got {ranges}"
                )
            start_idx, end_idx = ranges[0]

            # analyze_mask ranges are inclusive, so include end_idx.
            # NOTE(review): replayed source actions are NOT object-offset
            # transformed here, while the hold phase below is. Confirm this
            # asymmetry is intended.
            for a in action[start_idx:end_idx + 1]:
                current_obs = deepcopy(next_obs)
                next_obs, terminated = self._step_and_record(env, a, current_obs)
                if terminated:
                    flag = True

            if subtask_ind == num_subtasks - 1 and not flag:
                for _ in range(hold_steps):
                    current_obs = deepcopy(next_obs)
                    a = self.transform_action(action[end_idx], cur_object_pose, src_object_pose)
                    next_obs, terminated = self._step_and_record(env, a, current_obs)
                    if terminated:
                        flag = True
                        break
        return flag

    def generate(
        self,
        env
    ):
        """
        Replay the currently opened source demo (``self.f``) in ``env``.

        On success (some step terminated) the trajectory is saved via
        ``env.save_data()``. Otherwise the recorded steps are discarded,
        ``env.step_count`` is rolled back, and ``env.output_path`` is deleted
        if it exists. The source file is closed in all cases, including when
        an exception is raised.

        Returns:
            bool: True if a demonstration was generated and saved.
        """
        try:
            flag = self._replay_source_demo(env)
            if flag:
                env.save_data()
                print(f"Successfully generated a new demonstration, save to hdf5 file {env.output_path}")
            else:
                env.step_count -= len(env.current_traj_history)
                env.current_traj_history = []
                if os.path.exists(env.output_path):
                    os.remove(env.output_path)
                print(f"Failed to generate a new demonstration, remove the generated data {env.output_path}")
        finally:
            self._close_current_file()
        return flag



